import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from model.net.energy_head import LinearClassifier


class HEM(nn.Module):
    """Wrap an encoder and a head, and expose logit-based scoring helpers.

    ``forward`` returns whatever ``head`` returns. ``e`` and ``e_abs`` treat that
    output as a 2-D ``(batch, num_classes)`` logit tensor: they reduce or index
    along dim 1.
    """

    def __init__(self, enc, head):
        """
        Args:
            enc: module applied to the raw input ``x``.
            head: module applied to the encoder features. It is called as
                ``head(z)``, ``head(z, mcog=...)`` or ``head(z, out_cossim=True)``.
        """
        super().__init__()
        self.enc = enc
        self.head = head

    @staticmethod
    def _squeeze_keep_batch(z):
        """Drop every size-1 dimension of ``z`` except the leading (batch) dim."""
        # A bare ``z.squeeze()`` also removes the batch axis when the batch size
        # is 1. That turns e.g. (1, C, 1, 1) into (C,) and breaks the dim-1
        # reductions in ``e`` / ``e_abs``. Walk from the last dim down to dim 1
        # so that earlier indices stay valid after each squeeze.
        for dim in range(z.dim() - 1, 0, -1):
            if z.shape[dim] == 1:
                z = z.squeeze(dim)
        return z

    def forward(self, x, mcog=None):
        """Encode ``x`` and run the head on the squeezed features.

        Args:
            x: input batch for ``enc``. The leading dim is treated as the batch
                dim and is never squeezed away.
            mcog: optional extra argument. It is passed to ``head`` as
                ``mcog=`` only when it is not None; otherwise ``head`` is called
                with the features alone.

        Returns:
            The output of ``head``.
        """
        z = self._squeeze_keep_batch(self.enc(x))
        if mcog is None:
            return self.head(z)
        return self.head(z, mcog=mcog)

    @staticmethod
    def _score(logits, y):
        """Reduce 2-D logits per sample.

        Args:
            logits: tensor whose dim 1 indexes classes.
            y: None, or an integer (int64) tensor of class indices, one per row.

        Returns:
            ``logits.logsumexp(1)`` (shape ``(B,)``) when ``y`` is None.
            Otherwise the logit at class ``y`` for each row, with shape
            ``(B, 1)``, because ``gather`` keeps the shape of ``y[:, None]``.
        """
        if y is None:
            return logits.logsumexp(1)
        return torch.gather(logits, 1, y[:, None])

    def e(self, x, y=None):
        """Score ``x`` from the raw head logits (see ``_score`` for shapes).

        Note that no sign flip is applied. The value is the logsumexp over
        classes, or the selected logit when ``y`` is given.
        """
        return self._score(self(x), y)

    def e_abs(self, x, y=None):
        """Like ``e``, but computed on the absolute values of the logits."""
        return self._score(self(x).abs(), y)

    def ec(self, x):
        """Run the encoder and call the head with ``out_cossim=True``."""
        # NOTE(review): unlike ``forward``, the encoder output is not squeezed
        # here. Confirm that ``head`` handles the raw encoder shape on this path.
        z = self.enc(x)
        return self.head(z, out_cossim=True)
